import os
import openai
import random
import numpy as np
import json
import jsonlines
import time
from tqdm import tqdm
from rank_bm25 import BM25Okapi

# The API key is read from the environment rather than hard-coded, so it is
# never committed to version control. Any key that was previously written
# into this file should be considered leaked and revoked.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
openai.api_key = OPENAI_API_KEY

# Instruction placed at the very start of every prompt, before the in-context
# examples and the actual question.
start_prompt = '''
You need to pick up some sentences from a list of captions most related to given cue. Here are five examples.
'''

def ask_gpt4(question, attempt=1):
    """Send ``question`` to GPT-4 as a single user message and return the reply text.

    Transient failures (rate limiting, timeouts, connection problems, service
    unavailable) are retried in a loop with exponential backoff: the delay
    starts at 0.1 s, doubles each retry, and is capped at 30 s. Once more than
    20 s have passed since the first attempt, each retry is logged. Any other
    OpenAI error is re-raised as a generic ``Exception`` chained to the
    original error, so the real cause stays visible in the traceback.

    Args:
        question: Prompt text sent as the user message.
        attempt: Starting value of the attempt counter. It is only used in the
            retry log messages and is kept for backward compatibility.

    Returns:
        The ``content`` of the first choice in the chat-completion response.

    Raises:
        Exception: on any non-retryable ``openai.error.OpenAIError``.
    """
    messages = [{"role": "user", "content": question}]
    retryable = (
        openai.error.RateLimitError,
        openai.error.Timeout,
        openai.error.APIConnectionError,
        openai.error.ServiceUnavailableError,
    )
    start_time = time.time()
    delay = 0.1
    max_delay = 30.0
    while True:
        try:
            response = openai.ChatCompletion.create(
                model="gpt-4",
                max_tokens=1000,
                temperature=1.2,
                messages=messages,
            )
        except retryable as err:
            # Iterate instead of recursing: recursion overflows the stack under
            # sustained rate limiting and resets the elapsed-time clock.
            elapsed = time.time() - start_time
            if elapsed > 20:
                print(f"Attempt {attempt}: {type(err).__name__} after "
                      f"{elapsed:.0f} seconds. Retrying...")
            time.sleep(delay)
            delay = min(delay * 2, max_delay)
            attempt += 1
        except openai.error.OpenAIError as err:
            raise Exception("An unexpected problem occurred with OpenAI API") from err
        else:
            return response["choices"][0]["message"]["content"]
        



def read_jsonline(sample_file):
    """Render every training record as a one-sentence in-context example.

    Records are obtained from ``sample_file.iter()`` and must provide the
    keys ``cue``, ``captions`` and ``labels``. The "selected number of case"
    is the count of non-zero entries in ``labels``.

    Returns:
        A list with one formatted example string per record, in file order.
    """
    template = ("The given cue is %s, the selected number of case is %d, "
                "and the captions %s. The corresponding correct labels are %s")
    return [
        template % (rec['cue'],
                    np.count_nonzero(rec['labels']),
                    str(rec['captions']),
                    rec['labels'])
        for rec in sample_file.iter()
    ]

def _build_prompt(line, corpus):
    """Assemble the few-shot GPT-4 prompt for a single test record.

    The in-context examples are the ``corpus`` entries selected by
    ``line['mm_icl']``, whose values are treated as 1-based indices (hence
    ``idx - 1``). Examples are joined with newlines so they do not run
    together in the prompt. ``k`` is the number of non-zero entries in
    ``line['labels']``.
    """
    examples = "\n".join(corpus[idx - 1] for idx in line['mm_icl'])
    captions = line['captions']
    cue = line['cue']
    k = np.count_nonzero(line['labels'])
    question = ('\nNow choose the top %d sentences most related to the cue %s '
                'from captions: %s. ' % (k, cue, str(captions)))
    instruction = 'Directly return the %d sentences as answer.' % (k)
    return f'{start_prompt}{examples}{question}{instruction}'


if __name__ == "__main__":
    test_path = './data/winogavil/mmf_icl/10_12/test.jsonl'
    train_path = './data/winogavil/mmf_icl/10_12/train.jsonl'
    output_path = './gpt4_ans/winogavil/mmf_icl/10_12/test.jsonl'
    # 1-based index of the first test record to query. Earlier records are
    # skipped, so after an interrupted run set this to the record where it
    # stopped. The output file is opened in append mode, so earlier results
    # are kept.
    start_from = 1

    # Close both readers as soon as their contents have been loaded.
    with jsonlines.open(train_path) as sample_file:
        corpus = read_jsonline(sample_file)
    with jsonlines.open(test_path) as dataset:
        records = list(dataset.iter())

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(output_path, 'a') as outfile:
        progress = tqdm(records, desc='Process', unit='it')
        for num, line in enumerate(progress, start=1):
            if num < start_from:
                continue
            answer = ask_gpt4(_build_prompt(line, corpus))
            information = {
                'images': line['images'],
                'cue': line['cue'],
                'labels': line['labels'],
                'captions': line['captions'],
                'gpt4': answer,
            }
            json.dump(information, outfile)
            outfile.write('\n')
            # Flush each record so a killed run keeps every finished answer,
            # and the resume point can be read off the output file.
            outfile.flush()